import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import openvino as ov
import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG
from ultralytics.utils.metrics import ConfusionMatrix

import nncf

ROOT = os.path.dirname(os.path.realpath(__file__))

def validate(
    model: ov.Model,
    data_loader: torch.utils.data.DataLoader,
    validator: Validator,
    num_samples: Optional[int] = None,
) -> Tuple[Dict, int, int]:
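    """
    Run validation of an OpenVINO model using the ultralytics validator.

    Returns the metrics dict, the number of images processed, and the total
    number of labeled objects.
    """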
    validator.seen = 0
    validator.jdict = []
    validator.stats = []
    validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
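    # Reshape to dynamic spatial dimensions so the compiled model accepts any input resolution.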
    model.reshape({0: [1, 3, -1, -1]})
    compiled_model = ov.compile_model(model, device_name="CPU")
    output_layer = compiled_model.output(0)
    for batch_i, batch in enumerate(data_loader):
        if num_samples is not None and batch_i == num_samples:
            break
        batch = validator.preprocess(batch)
        preds = torch.from_numpy(compiled_model(batch["img"])[output_layer])
        preds = validator.postprocess(preds)
        validator.update_metrics(preds, batch)
    stats = validator.get_stats()
    return stats, validator.seen, validator.nt_per_class.sum()

def print_statistics(stats: Dict[str, float], total_images: int, total_objects: int) -> None:
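    """Print a COCO-style summary line (precision, recall, mAP@.5, mAP@.5:.95) over all classes."""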
    mp, mr, map50, mean_ap = (
        stats["metrics/precision(B)"],
        stats["metrics/recall(B)"],
        stats["metrics/mAP50(B)"],
        stats["metrics/mAP50-95(B)"],
    )
    s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95")
    print(s)
    pf = "%20s" + "%12i" * 2 + "%12.3g" * 4  # print format
    print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))

def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
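    """Build the ultralytics validator and a dataloader for the validation split of args.data."""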
    validator = model.smart_load("validator")(args)
    validator.data = check_det_dataset(args.data)
    # validator.data is the parsed dataset config, e.g.:
    # {'path': PosixPath(...), 'train': ..., 'val': ..., 'names': {0: 'person'}, 'nc': 1, 'yaml_file': ...}

    dataset = validator.data["val"]
    print(f"Using dataset: {dataset}, DATASETS_DIR: {DATASETS_DIR}")

    data_loader = validator.get_dataloader(dataset, 1)

    validator.is_coco = "coco" in str(args.data)
    if validator.is_coco:
        # Map the 80 contiguous COCO class indices to the 91-class ids used by pycocotools.
        validator.class_map = coco80_to_coco91_class()
    validator.names = model.model.names
    validator.metrics.names = validator.names
    validator.nc = model.model.model[-1].nc

    return validator, data_loader

def benchmark_performance(model_path: Path, config: Any) -> float:
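    """Measure async throughput (FPS) with OpenVINO's benchmark_app over a 30-second run on CPU."""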
    command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
    command += f' -shape "[1,3,{config.imgsz},{config.imgsz}]"'
    cmd_output = subprocess.check_output(command, shell=True).decode()  # nosec

    match = re.search(r"Throughput: (.+?) FPS", cmd_output)
    if match is None:
        raise RuntimeError(f"Failed to parse throughput from benchmark_app output:\n{cmd_output}")
    return float(match.group(1))

def prepare_openvino_model(model: YOLO, model_name: str, model_path: str) -> Tuple[ov.Model, Path]:
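    """Export the YOLO model to ONNX if needed and convert it to OpenVINO IR, returning the model and the IR path."""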
    ir_model_path = Path(f"{ROOT}/{model_name}_openvino_model/{model_name}.xml")
    if not ir_model_path.exists():
        onnx_model_path = Path(model_path.replace(".pt", ".onnx"))
        if not onnx_model_path.exists():
            model.export(format="onnx", half=False)

        ir_model_path.parent.mkdir(parents=True, exist_ok=True)
        ov.save_model(ov.convert_model(onnx_model_path), ir_model_path)
    return ov.Core().read_model(ir_model_path), ir_model_path

def quantize(model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> ov.Model:
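    """
    Run NNCF post-training quantization on the OpenVINO model, using the
    validation dataloader as the calibration dataset.
    """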
    def transform_fn(data_item: Dict):
        """
        Quantization transform function. Extracts and preprocess input data from dataloader
        item for quantization.
        Parameters:
        data_item: Dict with data item produced by DataLoader during iteration
        Returns:
            input_tensor: Input data for quantization
        """
        input_tensor = validator.preprocess(data_item)["img"].numpy()
        return input_tensor

    quantization_dataset = nncf.Dataset(data_loader, transform_fn)

    quantized_model = nncf.quantize(
        model,
        quantization_dataset,
        preset=nncf.QuantizationPreset.MIXED,
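        # Keep pre/post-processing ops and the detection-head subgraph in floating
        # point; quantizing them is known to degrade YOLOv8 accuracy. The node names
        # below correspond to the YOLOv8 head in the exported ONNX graph.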
        ignored_scope=nncf.IgnoredScope(
            types=["Multiply", "Subtract", "Sigmoid"],
            subgraphs=[
                nncf.Subgraph(
                    inputs=["/model.22/Concat", "/model.22/Concat_1", "/model.22/Concat_2"],
                    outputs=["output0/sink_port_0"],
                )
            ],
        ),
    )
    return quantized_model

def main(model_path: str, data_config: str) -> Tuple[float, float, float, float]:
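    """Quantize model_path with NNCF and report FP32/INT8 accuracy (mAP50-95) and throughput."""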
    model_name = Path(model_path).stem

    model = YOLO(model_path)
    args = get_cfg(cfg=DEFAULT_CFG)
    args.data = data_config

    # Prepare validation dataset and helper
    validator, data_loader = prepare_validation(model, args)

    # Convert to OpenVINO model
    ov_model, ov_model_path = prepare_openvino_model(model, model_name, model_path)

    # Quantize the model in OpenVINO representation
    quantized_model = quantize(ov_model, data_loader, validator)
    quantized_model_path = Path(f"{ROOT}/{model_name}_openvino_model/{model_name}_quantized.xml")
    ov.save_model(quantized_model, str(quantized_model_path))

    # Validate FP32 model
    fp_stats, total_images, total_objects = validate(ov_model, tqdm(data_loader), validator)
    print("Floating-point model validation results:")
    print_statistics(fp_stats, total_images, total_objects)

    # Validate quantized model
    q_stats, total_images, total_objects = validate(quantized_model, tqdm(data_loader), validator)
    print("Quantized model validation results:")
    print_statistics(q_stats, total_images, total_objects)

    # Benchmark performance of FP32 model
    fp_model_perf = benchmark_performance(ov_model_path, args)
    print(f"Floating-point model performance: {fp_model_perf} FPS")

    # Benchmark performance of quantized model
    quantized_model_perf = benchmark_performance(quantized_model_path, args)
    print(f"Quantized model performance: {quantized_model_perf} FPS")

    # Quantization is finished; remove the runs folder created during validation.
    shutil.rmtree(f"{ROOT}/runs", ignore_errors=True)
    print("Quantization finished")

    return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf

if __name__ == "__main__":
    # Read the model path and the dataset config from the command-line arguments.
    if len(sys.argv) != 3:
        print("Usage: python script.py <model_path> <config>")
        sys.exit(1)

    model_path = sys.argv[1]
    data_config = sys.argv[2]
    main(model_path, data_config)
